#ifndef _LIBSSH2_CPP_H_
#define _LIBSSH2_CPP_H_

#include <atomic>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <functional>
#include <mutex>
#include <string>
#include <thread>

#include <stdint.h>
#include <time.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/select.h>
#include <sys/socket.h>

#include <libssh2.h>

namespace libssh2_cpp {

/// Result codes returned by the libssh2_cpp wrappers.
/// The values are written out explicitly; they are identical to the implicit
/// numbering (noError = 0, every following enumerator one higher).
enum ErrorCode {
    noError                 = 0,    ///< success
    statusError             = 1,    ///< object is in the wrong lifecycle state for this call
    initError               = 2,    ///< libssh2_init() failed
    socketError             = 3,    ///< socket() failed
    connectError            = 4,    ///< TCP connection could not be established
    initSessionError        = 5,    ///< libssh2_session_init_ex() returned nullptr
    handShakeError          = 6,    ///< SSH handshake failed
    authMethodError         = 7,    ///< server auth-method list unavailable or has no recognized method
    authMethodNotAllowError = 8,    ///< none of the methods this wrapper supports is offered
    authError               = 9,    ///< the server rejected the credentials
    needPubkeyError         = 10,   ///< publickey auth requires key file paths that were not given
    openSessionError        = 11,   ///< libssh2_channel_open_session() failed
    requestPtyError         = 12,   ///< PTY request was refused
    openShellError          = 13,   ///< shell request was refused
    allocMemError           = 14,   ///< memory allocation failure
};

/// Lifecycle states used as the session status value (0 means idle / not connected).
/// The values are written out explicitly; they match the implicit numbering.
enum Status {
    initStatus        = 1,  ///< connect_ssh2() is in progress
    connectedStatus   = 2,  ///< TCP connected and SSH handshake completed
    authedStatus      = 3,  ///< user authenticated
    openSessionStatus = 4,  ///< channel opened with a PTY and a shell
    runCmdStatus      = 5,  ///< background read loop started
};

/// Process-wide reference counter around libssh2_init() / libssh2_exit().
///
/// Every successful init() must be balanced by exactly one exit(). The library
/// is initialized by the first init() and torn down by the exit() that brings
/// the count back to zero. Both calls are serialized by an internal mutex.
class SSH2LibInit
{
private:
    std::mutex lib_mutex_;
    uint32_t   inst_count_;

    explicit SSH2LibInit() : inst_count_(0) {};
public:
    /// Returns the singleton instance (a function-local static).
    static SSH2LibInit* get_inst() {
        static SSH2LibInit object_;
        return &object_;
    }
    /// Takes one reference on libssh2, initializing the library for the first user.
    /// @return noError on success; initError if libssh2_init() fails (count unchanged).
    ErrorCode init() {
        std::lock_guard<std::mutex> lock_(lib_mutex_);
        if (0 == inst_count_) {
            // First user: initialize the libssh2 library.
            if (0 != libssh2_init(0)) {
                return initError;
            }
        }
        // Count every caller, not only the first one. Otherwise the count is
        // capped at 1 and the first exit() would call libssh2_exit() while
        // other users still depend on the library.
        inst_count_++;
        return noError;
    }
    /// Releases one reference; calls libssh2_exit() when the last one is released.
    /// Calls made while the count is already zero are ignored.
    void exit() {
        std::lock_guard<std::mutex> lock_(lib_mutex_);
        if (inst_count_ > 0) {
            inst_count_ --;
            if (inst_count_ == 0) {
                libssh2_exit();
            }
        }
    }
};

// Waits with select() until `socket_` is ready in whichever direction(s)
// libssh2 reports the session is blocked on, or until `timeout_second`
// seconds elapse.
// Returns select()'s result: > 0 ready, 0 on timeout, -1 on error (errno set).
// If libssh2 reports no blocking direction, both sets are NULL and the call
// simply sleeps for the timeout.
inline int _wait_socket_ready(int socket_, LIBSSH2_SESSION* session_, uint32_t timeout_second = 10) {
    // FD_SET on a negative descriptor, or on one >= FD_SETSIZE, writes outside
    // the fd_set (undefined behavior). Reject such descriptors up front.
    if ((socket_ < 0) || (socket_ >= FD_SETSIZE)) {
        errno = EBADF;
        return -1;
    }

    struct timeval timeout;
    timeout.tv_sec = timeout_second;
    timeout.tv_usec = 0;

    fd_set fd;
    FD_ZERO(&fd);
    FD_SET(socket_, &fd);

    // Wait only in the direction(s) libssh2 is actually blocked on.
    const int dir = libssh2_session_block_directions(session_);
    fd_set *readfd  = (dir & LIBSSH2_SESSION_BLOCK_INBOUND)  ? &fd : NULL;
    fd_set *writefd = (dir & LIBSSH2_SESSION_BLOCK_OUTBOUND) ? &fd : NULL;

    printf("_wait_socket_ready [%d].\n", socket_);

    return select(socket_ + 1, readfd, writefd, NULL, &timeout);
}

// Queries the server's authentication methods for `username` and encodes the
// supported ones into the `auth_method` bit mask:
//   1 = "password", 2 = "keyboard-interactive", 4 = "publickey".
// Matching is a substring search on the comma-separated list from libssh2.
// Returns authMethodError if the list cannot be obtained or contains none of
// these methods; noError otherwise.
inline ErrorCode _check_auth_methods(LIBSSH2_SESSION* session_, const std::string &username, uint32_t& auth_method){
    auth_method = 0;
    char * userauthlist = NULL;
    int ret = 0;
    do {
        // Reset on every attempt. If a stale LIBSSH2_ERROR_EAGAIN from an
        // earlier attempt were kept, the loop would keep spinning after a
        // later call succeeded.
        ret = 0;
        userauthlist = libssh2_userauth_list(session_, username.c_str(), static_cast<unsigned int>(username.size()));
        if (NULL == userauthlist) {
            ret = libssh2_session_last_errno(session_);
            if (ret != LIBSSH2_ERROR_EAGAIN) {
                // Only report real failures, not "would block" retries.
                char* err_msg = NULL;
                libssh2_session_last_error(session_, &err_msg, nullptr, 0);
                printf("error: [%s].\n", (err_msg != NULL) ? err_msg : "");
            }
        }
    } while (ret == LIBSSH2_ERROR_EAGAIN);

    if (NULL == userauthlist) {
        return authMethodError;
    }
    if (strstr(userauthlist, "password") != NULL) {
        auth_method |= 1;
    }
    if (strstr(userauthlist, "keyboard-interactive") != NULL) {
        auth_method |= 2;
    }
    if (strstr(userauthlist, "publickey") != NULL) {
        auth_method |= 4;
    }
    if (auth_method == 0) {
        return authMethodError;
    }
    return noError;
}

// Owns one TCP socket plus one LIBSSH2_SESSION and tracks a coarse lifecycle
// status (see enum Status; 0 = idle).
//
// Copying is disabled implicitly by the user-declared move constructor. The
// destructor calls disconnect(), which releases whatever connect_ssh2()
// acquired and gives back the libssh2 reference taken by init_libssh2().
class Ssh2Session
{
protected:
    int                 sock_;
    LIBSSH2_SESSION *   session_;

private:
    uint32_t            session_status_;
    // True while this object holds a reference obtained through
    // init_libssh2(); disconnect() releases it exactly once.
    bool                lib_inited_;

    // Frees the libssh2 session and closes the socket (whichever of them
    // exist), then resets the status to 0. The SSH disconnect message is only
    // sent once the handshake has completed (status >= connectedStatus).
    void _release_transport() {
        if (session_ != nullptr) {
            if (session_status_ >= connectedStatus) {
                libssh2_session_disconnect(session_, "Destruct Shutdown");
            }
            libssh2_session_free(session_);
            session_ = nullptr;
        }
        if (sock_ >= 0) {
            ::close(sock_);
            sock_ = -1;
        }
        session_status_ = 0;
    }

public:
    inline LIBSSH2_SESSION* get_session()           {return session_;};
    inline uint32_t  get_status()const              {return session_status_;};
    inline void      set_status(uint32_t status_)   {session_status_ = status_;};

public:
    explicit Ssh2Session(): sock_(-1), session_(nullptr), session_status_(0), lib_inited_(false) {
    };
    ~Ssh2Session() {
        disconnect();
    };
    // Move constructor: takes over the socket, session, status and library
    // reference, leaving `d` idle.
    // NOTE(review): the session's abstract pointer (used by keyboard-interactive
    // callbacks) still refers to the moved-from object after a move.
    Ssh2Session(Ssh2Session &&d) noexcept
        : sock_(d.sock_), session_(d.session_), session_status_(d.session_status_), lib_inited_(d.lib_inited_) {
        d.sock_             = -1;
        d.session_          = nullptr;
        d.session_status_   = 0;
        d.lib_inited_       = false;
    };
    // Takes (at most) one libssh2 reference for this object; disconnect()
    // gives it back. Repeated calls while a reference is held are no-ops.
    ErrorCode init_libssh2() {
        if (lib_inited_) {
            return noError;
        }
        ErrorCode ret = SSH2LibInit::get_inst()->init();
        if (noError == ret) {
            lib_inited_ = true;
        }
        return ret;
    }
    // Connects to target_ip:port over IPv4 and performs the SSH handshake.
    //   target_ip - IPv4 address text (any form accepted by inet_aton)
    //   port      - TCP port, default 22
    // Returns statusError unless the object is idle (status 0). On any failure
    // every resource acquired so far is released and the status goes back to
    // 0, so the call can simply be retried.
    ErrorCode connect_ssh2(const char* target_ip, uint16_t port = 22) {
        // Only an idle object may start a connection.
        if (0 != get_status()) {
            return statusError;
        }
        set_status(initStatus);

        // The caller supplies the address; this class creates the socket.
        sock_ = socket(AF_INET, SOCK_STREAM, 0);
        if (sock_ < 0) {
            sock_ = -1;
            _release_transport();
            return socketError;
        }

        struct sockaddr_in sin;
        memset(&sin, 0, sizeof(sin));
        sin.sin_family = AF_INET;
        sin.sin_port = htons(port);
        // inet_aton accepts the same text forms as inet_addr but reports
        // failure explicitly, so malformed input is rejected instead of
        // silently becoming INADDR_NONE (255.255.255.255).
        if ((nullptr == target_ip) || (0 == inet_aton(target_ip, &sin.sin_addr))) {
            _release_transport();
            return connectError;
        }
        if (::connect(sock_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin)) != 0) {
            _release_transport();
            return connectError;
        }

        // Create the session; `this` becomes the session's abstract pointer,
        // which libssh2 hands back to keyboard-interactive callbacks.
        session_ = libssh2_session_init_ex(nullptr, nullptr, nullptr, this);
        if (nullptr == session_) {
            _release_transport();
            return initSessionError;
        }

        // Handshake and user authentication run in blocking mode.
        libssh2_session_set_blocking(session_, 1);
        int ret = 0;
        do { ret = libssh2_session_handshake(session_, sock_); } while (ret == LIBSSH2_ERROR_EAGAIN);
        if (0 != ret) {
            _release_transport();
            return handShakeError;
        }
        // Debug aid: print the server's SHA1 host-key fingerprint.
#if 0
        const char *fingerprint = libssh2_hostkey_hash(session_, LIBSSH2_HOSTKEY_HASH_SHA1);
        fprintf(stderr, "Fingerprint: ");
        for(int i = 0; i < 20; i++) {
            fprintf(stderr, "%02X ", (unsigned char)fingerprint[i]);
        }
        fprintf(stderr, "\n");
#endif

        set_status(connectedStatus);
        return noError;
    };
    // Tears down the connection (whatever stage it reached), resets the status
    // to 0 and releases the libssh2 reference taken by init_libssh2().
    // Safe to call repeatedly.
    void disconnect() {
        if ((session_ != nullptr) || (sock_ >= 0)) {
            printf("disconnect session .\n");
        }
        _release_transport();
        if (lib_inited_) {
            lib_inited_ = false;
            SSH2LibInit::get_inst()->exit();
        }
    }
};

// CRTP mix-in: authenticates with a password, using the first method the
// server offers among password, keyboard-interactive and publickey (from
// files). `Base` must provide get_status(), get_session() and set_status().
template<class Base>
class Ssh2UserAuth1
{
protected:
    // Password answered by the keyboard-interactive callback; cleared as soon
    // as the keyboard-interactive exchange returns.
    std::string         password_;

    // CRTP downcast to the concrete class.
    Base* _auth1_self() { return static_cast<Base*>(this); }

    // Static trampoline for libssh2's C callback: forwards to the member
    // callback of the object stored as the session abstract pointer.
    // NOTE(review): that pointer is registered as `this` inside Ssh2Session;
    // converting it straight to Base* assumes the Ssh2Session subobject sits at
    // the same address as the Base object.
    static void __kbd_1_callback(const char *name, int name_len, const char *instruction, int instruction_len, int num_prompts,
                            const LIBSSH2_USERAUTH_KBDINT_PROMPT *prompts, LIBSSH2_USERAUTH_KBDINT_RESPONSE *responses, void **abstract){
        if ((nullptr != abstract) && (nullptr != *abstract)) {
            Base* self = static_cast<Base*>(*abstract);
            self -> Base::_kbd_1_callback(name, name_len, instruction, instruction_len, num_prompts, prompts, responses);
        }
    }
    // Answers a single-prompt challenge with password_; other prompt counts
    // get no answer.
    void _kbd_1_callback(const char *name, int name_len, const char *instruction, int instruction_len, int num_prompts,
                            const LIBSSH2_USERAUTH_KBDINT_PROMPT *prompts, LIBSSH2_USERAUTH_KBDINT_RESPONSE *responses) {
        (void)name;
        (void)name_len;
        (void)instruction;
        (void)instruction_len;
        (void)prompts;
        if (num_prompts != 1) {
            return;
        }
        // The response text is a heap copy handed over to libssh2.
        char* text = strdup(password_.c_str());
        if (nullptr == text) {
            // Out of memory: leave the response empty rather than reporting a
            // non-zero length for a NULL buffer.
            return;
        }
        responses[0].text = text;
        responses[0].length = static_cast<unsigned int>(strlen(text));
    }
public:
    // Authenticates `username`, trying the first available method in the order
    // password -> keyboard-interactive -> publickey.
    //   password - the password, or the key passphrase for publickey
    //   keyfile1 - public key file path (publickey only)
    //   keyfile2 - private key file path (publickey only)
    // Returns statusError unless the status is connectedStatus or authedStatus,
    // needPubkeyError if only publickey is offered and a key path is missing,
    // authError if the server rejects the credentials, and
    // authMethodNotAllowError if no supported method is offered.
    // On success the status becomes authedStatus.
    ErrorCode user_auth1(const std::string &username, const std::string &password, const char* keyfile1 = nullptr, const char* keyfile2 = nullptr) {
        int ret = 0;
        uint32_t auth_method = 0;
        const uint32_t status = _auth1_self()-> Base::get_status();
        if ((connectedStatus != status) && (authedStatus != status)) {
            return statusError;
        }
        LIBSSH2_SESSION* session = _auth1_self()-> Base::get_session();
        ErrorCode ret2 = _check_auth_methods(session, username, auth_method);
        if (noError != ret2) {return ret2;}

        if(auth_method & 1) {
            // Plain password authentication.
            do { ret = libssh2_userauth_password(session, username.c_str(), password.c_str()); } while(ret == LIBSSH2_ERROR_EAGAIN);
            if(ret) {
                return authError;
            }
        } else if(auth_method & 2) {
            // Keyboard-interactive; the callback answers with password_.
            password_ = password;
            do { ret = libssh2_userauth_keyboard_interactive(session, username.c_str(), &__kbd_1_callback); } while(ret == LIBSSH2_ERROR_EAGAIN);
            // The credential is no longer needed once the exchange is over.
            password_.clear();
            if(ret) {
                return authError;
            }
        } else if(auth_method & 4) {
            if ((keyfile1 == nullptr) || (keyfile2 == nullptr)) {
                return needPubkeyError;
            }
            do { ret = libssh2_userauth_publickey_fromfile(session, username.c_str(), keyfile1, keyfile2, password.c_str()); } while(ret == LIBSSH2_ERROR_EAGAIN);
            if(ret) {
                return authError;
            }
        }
        else {
            return authMethodNotAllowError;
        }
        _auth1_self()-> Base::set_status(authedStatus);
        return noError;
    };

    Ssh2UserAuth1() {};
    // Move constructor.
    Ssh2UserAuth1(Ssh2UserAuth1 &&d) {
        password_ = std::move(d.password_);
    };
};

// CRTP mix-in: keyboard-interactive authentication where the password is
// supplied on demand by a caller-provided handler. `Base` must provide
// get_status(), get_session() and set_status().
template<class Base>
class Ssh2UserAuth2
{
public:
    // Fills `passwd` and returns true to answer the prompt; returns false to decline.
    using password_handle  = std::function<bool(std::string& passwd)>;

protected:
    password_handle  handle_;

    // CRTP downcast to the concrete class.
    Base* _auth2_self() { return static_cast<Base*>(this); }

protected:
    // Static trampoline for libssh2's C callback: forwards to the member
    // callback of the object stored as the session abstract pointer.
    // NOTE(review): that pointer is registered as `this` inside Ssh2Session;
    // converting it straight to Base* assumes the Ssh2Session subobject sits at
    // the same address as the Base object.
    static void __kbd_2_callback(const char *name, int name_len, const char *instruction, int instruction_len, int num_prompts,
                            const LIBSSH2_USERAUTH_KBDINT_PROMPT *prompts, LIBSSH2_USERAUTH_KBDINT_RESPONSE *responses, void **abstract){
        if ((nullptr != abstract) && (nullptr != *abstract)) {
            Base* self = static_cast<Base*>(*abstract);
            self -> Base::_kbd_2_callback(name, name_len, instruction, instruction_len, num_prompts, prompts, responses);
        }
    }
    // Answers a single-prompt challenge with the password obtained from
    // handle_. Nothing is answered if there is not exactly one prompt, no
    // handler is set, or the handler declines or throws.
    void _kbd_2_callback(const char *name, int name_len, const char *instruction, int instruction_len, int num_prompts,
                            const LIBSSH2_USERAUTH_KBDINT_PROMPT *prompts, LIBSSH2_USERAUTH_KBDINT_RESPONSE *responses)
    {
        (void)name;
        (void)name_len;
        (void)instruction;
        (void)instruction_len;
        (void)prompts;
        // Only ask for the password when it will actually be sent. An empty
        // std::function would throw bad_function_call here.
        if ((num_prompts != 1) || !handle_) {
            return;
        }

        std::string password;
        bool accepted = false;
        // This runs inside libssh2's C call stack; an exception must not
        // unwind through it, so a throwing handler counts as "declined".
        try {
            accepted = handle_(password);
        } catch (...) {
            accepted = false;
        }
        if (!accepted) {
            return;
        }

        // The response text is a heap copy handed over to libssh2.
        char* text = strdup(password.c_str());
        if (nullptr == text) {
            // Out of memory: leave the response empty rather than reporting a
            // non-zero length for a NULL buffer.
            return;
        }
        responses[0].text = text;
        responses[0].length = static_cast<unsigned int>(strlen(text));
    }
public:
    // Authenticates `username` via keyboard-interactive, obtaining the
    // password from `handle` when the server prompts.
    // Returns statusError unless the status is connectedStatus or authedStatus,
    // authMethodNotAllowError if keyboard-interactive is not offered, and
    // authError if authentication fails. On success the status becomes authedStatus.
    ErrorCode user_auth2(const std::string &username, password_handle handle) {
        int ret = 0;
        uint32_t auth_method = 0;
        const uint32_t status = _auth2_self()-> Base::get_status();
        if ((connectedStatus != status) && (authedStatus != status)) {
            return statusError;
        }
        LIBSSH2_SESSION* session = _auth2_self()-> Base::get_session();

        ErrorCode ret2 = _check_auth_methods(session, username, auth_method);
        if (noError != ret2) {return ret2;}

        if (0 == (auth_method & 2)) {
            return authMethodNotAllowError;
        }

        handle_ = std::move(handle);
        do { ret = libssh2_userauth_keyboard_interactive(session, username.c_str(), &__kbd_2_callback); } while(ret == LIBSSH2_ERROR_EAGAIN);
        if(ret) {
            printf("libssh2_userauth_keyboard_interactive fail: %d.\n", ret);
            return authError;
        }
        _auth2_self()-> Base::set_status(authedStatus);
        return noError;
    };
    Ssh2UserAuth2() {};
    // Move constructor.
    Ssh2UserAuth2(Ssh2UserAuth2 &&d) {
        handle_ = std::move(d.handle_);
    };
};

// Owns a single LIBSSH2_CHANNEL and exposes thin eof/read/write wrappers.
// The channel is closed and freed on destruction; moving transfers ownership.
class Ssh2ChannelBase
{
protected:
    LIBSSH2_CHANNEL*   channel_;

    // Closes and frees the current channel, if any, and clears the handle.
    void _free_channel() {
        if (nullptr != channel_) {
            printf("close channel .\n");
            libssh2_channel_close(channel_);
            libssh2_channel_free(channel_);
            channel_ = nullptr;
        }
    }

    // Opens a new "session" channel on `session`. A channel this object still
    // holds (e.g. from an earlier open whose pty/shell request failed) is
    // released first; overwriting the pointer would leak it.
    // Returns openSessionError if libssh2_channel_open_session() fails.
    ErrorCode open_session(LIBSSH2_SESSION * session) {
        _free_channel();
        channel_ = libssh2_channel_open_session(session);
        if(!channel_) {
            if (0) {  // debug aid: change to 1 to print libssh2's error text
                char* err_msg = NULL;
                libssh2_session_last_error(session, &err_msg, nullptr, 0);
                printf("error: [%s].\n", err_msg);
            }
            return openSessionError;
        }
        // Flush the new channel's normal and stderr streams right away.
        libssh2_channel_flush(channel_);
        libssh2_channel_flush_stderr(channel_);
        return noError;
    }

public:
    Ssh2ChannelBase(): channel_(nullptr) {}
    ~Ssh2ChannelBase() {
        _free_channel();
    }
    // Move constructor: takes over the channel, leaving `d` empty.
    Ssh2ChannelBase(Ssh2ChannelBase &&d) noexcept : channel_(d.channel_) {
        d.channel_  = nullptr;
    }
    // Forwards to libssh2_channel_eof() on the owned channel.
    inline int channel_eof() {
        return libssh2_channel_eof(channel_);
    }
    // Reads up to `buflen` bytes from stream `stream_id` of the owned channel.
    inline ssize_t channel_read_ex(int stream_id, char *buf, size_t buflen) {
        return libssh2_channel_read_ex(channel_, stream_id, buf, buflen);
    }
    // Writes up to `buflen` bytes to stream `stream_id` of the owned channel.
    inline ssize_t channel_write_ex(int stream_id, const char *buf, size_t buflen) {
        return libssh2_channel_write_ex(channel_, stream_id, buf, buflen);
    }
};

// Owns the background reader thread that forwards channel output to callbacks.
// The thread captures `this`, so the class can be neither copied nor moved.
class Ssh2ChannelData: public Ssh2ChannelBase
{
public:
    using channel_data_handler  = std::function<void(uint8_t* , uint32_t)>;
    using channel_eof_handler   = std::function<void()>;
    using channel_wait_handler  = std::function<int(uint32_t)>;

private:
    channel_data_handler    dataout_handler_;
    channel_eof_handler     eof_handler_;
    std::thread             main_thread_;

    // Flags shared between the owning thread and the reader thread.
    // std::atomic gives well-defined cross-thread visibility; volatile does not.
    std::atomic<bool>       is_quit_;
    std::atomic<bool>       is_start_;

    // Reader-thread body. Waits until trigger_start() (or term_loop()), then
    // repeatedly reads stream 0 of the channel and passes each chunk to
    // data_handle. It stops when quit is requested, when the channel reports
    // EOF, or when a read fails with anything other than EAGAIN. In the last
    // two cases eof_handle is invoked.
    //   wait_handler - called with 10 on LIBSSH2_ERROR_EAGAIN to wait for the socket
    //   eof_handle   - end-of-stream notification (may be empty)
    //   data_handle  - receives (buffer, byte_count) per successful read (may be empty)
    void _read_data_func(channel_wait_handler wait_handler, channel_eof_handler& eof_handle, channel_data_handler& data_handle) {
        struct timespec req;
        memset(&req, 0, sizeof(req));
        req.tv_nsec = 1*1000;
        // Poll in 1 us naps until started or asked to quit.
        while (!is_start_) {
            if (is_quit_) {
                printf("read data exit.\n");
                return;
            }
            nanosleep(&req, nullptr);
        }

        req.tv_nsec = 10*1000;
        printf("read data loop.\n");
        uint8_t buffer[0x4000];
        while (true) {
            if (is_quit_) {
                printf("read data exit.\n");
                return;
            }
            if (channel_eof()) {
                printf("channel eof, read data exit.\n");
                if (eof_handle)
                    eof_handle();
                return;
            }
            // Zero-filled so that a short read is followed by NUL bytes.
            memset(buffer, 0, sizeof(buffer));
            const ssize_t ret = channel_read_ex(0, (char*)buffer, sizeof(buffer));
            if (is_quit_) {
                printf("read data exit.\n");
                return;
            }
            if (ret > 0) {
                // Data available.
                if (data_handle)
                    data_handle(buffer, static_cast<uint32_t>(ret));
            } else if (ret == 0) {
                // No data this round.
                nanosleep(&req, nullptr);
            } else if (ret == LIBSSH2_ERROR_EAGAIN) {
                // Would block: wait for the socket (or just nap if no waiter).
                if (wait_handler)
                    wait_handler(10);
                else
                    nanosleep(&req, nullptr);
            } else {
                // Any other negative code is a channel/transport failure that
                // will not clear by retrying. Retrying forever would spin the
                // thread, so report end-of-stream and stop.
                printf("channel read error %d, read data exit.\n", static_cast<int>(ret));
                if (eof_handle)
                    eof_handle();
                return;
            }
        }
    };
public:
    // Asks the reader thread to stop and waits for it to finish. Safe when no
    // thread is running. Can block while the reader sits in wait_handler.
    void term_loop() {
        is_quit_ = true;
        if (main_thread_.joinable()) {
            main_thread_.join();
        }
    };
    // Starts the reader thread; it idles until trigger_start() is called.
    // The data/EOF handlers registered at this moment are copied into the thread.
    ErrorCode start_loop(channel_wait_handler handler) {
        // Assigning over a still-joinable std::thread calls std::terminate(),
        // so stop any previous reader first.
        term_loop();
        is_quit_ = false;
        is_start_ = false;
        channel_eof_handler  eof_copy  = eof_handler_;
        channel_data_handler data_copy = dataout_handler_;
        main_thread_ = std::thread([this, handler, eof_copy, data_copy]() mutable {
            _read_data_func(handler, eof_copy, data_copy);
        });
        return noError;
    }
    // Lets a started reader thread begin reading.
    void trigger_start() {
        is_start_ = true;
    }
    // Registers the end-of-stream callback (takes effect at the next start_loop()).
    void ChannelEof(channel_eof_handler handle_) {
        eof_handler_ = handle_;
    }
    // Registers the data callback (takes effect at the next start_loop()).
    void ChannelRead(channel_data_handler handle_) {
        dataout_handler_ = handle_;
    };

    // The reader thread holds `this`: copying and moving are disabled.
    Ssh2ChannelData(const Ssh2ChannelData&) = delete;
    Ssh2ChannelData(Ssh2ChannelData&&) = delete;
    Ssh2ChannelData &operator=(const Ssh2ChannelData&) = delete;
    Ssh2ChannelData &operator=(Ssh2ChannelData&&) = delete;

public:
    Ssh2ChannelData(): is_quit_(false), is_start_(false) {
    }
    // Joins the reader thread before the channel (owned by the base class) is
    // freed. Destroying a joinable std::thread would call std::terminate().
    ~Ssh2ChannelData() {
        term_loop();
    }
};

// Interactive shell over libssh2: one connection (Ssh2Session) plus one
// channel with a background reader thread (Ssh2ChannelData).
class Ssh2ShellBase: public Ssh2Session, public Ssh2ChannelData
{
public:
    explicit Ssh2ShellBase() : Ssh2Session(), Ssh2ChannelData() {
    };
    // Stops the shell before the base-class destructors run (bases are destroyed
    // in reverse order: the channel classes first, then Ssh2Session).
    ~Ssh2ShellBase() {
        term_shell();
    };
    // Opens a channel, requests an "xterm" PTY of width x height characters
    // and starts a shell on it. Requires authedStatus; on success the status
    // becomes openSessionStatus.
    // Returns statusError, openSessionError, requestPtyError or openShellError on failure.
    ErrorCode open_shell(int width = LIBSSH2_TERM_WIDTH, int height = LIBSSH2_TERM_HEIGHT) {
        if (authedStatus != get_status()) {
            return statusError;
        }
        ErrorCode ret = open_session(session_);
        if(noError != ret) {
            return ret;
        }
        // Request a terminal with 'vanilla' terminal emulation; see
        // /etc/termcap for other terminal types.
        static const char kTermType[] = "xterm";
        // term_len is the byte length of the name. sizeof() alone would also
        // count the trailing NUL and send "xterm\0" to the server.
        if(libssh2_channel_request_pty_ex(channel_, kTermType, sizeof(kTermType) - 1, NULL, 0, width, height, LIBSSH2_TERM_WIDTH_PX, LIBSSH2_TERM_HEIGHT_PX)) {
            return requestPtyError;
        }
        // Start a shell on that pty.
        if(libssh2_channel_shell(channel_)) {
            return openShellError;
        }
        set_status(openSessionStatus);
        return noError;
    }
    // Switches the session to non-blocking mode and starts the reader thread,
    // which waits on the socket via _wait_socket_ready. Requires
    // openSessionStatus; on success the status becomes runCmdStatus.
    ErrorCode run_shell() {
        if (openSessionStatus != get_status()) {
            return statusError;
        }
        // Once the shell is open, the session is used in non-blocking mode.
        libssh2_session_set_blocking(session_, 0);

        ErrorCode ret = start_loop(std::bind(&_wait_socket_ready, sock_, session_, std::placeholders::_1));
        if (noError == ret) {
            set_status(runCmdStatus);
        }
        return ret;
    }
    // Stops the reader thread if a shell was opened.
    // NOTE(review): the status is set to authedStatus unconditionally, even
    // when the object was never connected or authenticated.
    void term_shell() {
        if (get_status() >= openSessionStatus) {
            term_loop();
        }
        set_status(authedStatus);
    }
};

/// Concrete interactive-shell client: connection, channel and reader thread
/// come from Ssh2ShellBase; both authentication front-ends are mixed in via
/// CRTP (user_auth1 for password/keys, user_auth2 for a password callback).
class Ssh2Shell: public Ssh2ShellBase, public Ssh2UserAuth1<Ssh2Shell>, public Ssh2UserAuth2<Ssh2Shell>
{
public:
    // All bases are default-constructed.
    Ssh2Shell() {}

    ~Ssh2Shell() {
        printf("destruction Ssh2Shell.\n");
    }
};

};

#endif
